import json
import os
import random
import shutil
import threading
import time

from websocket import create_connection

import common
import mongo

Common = common.CommonFun()

# Load-test target and database settings. Each value can be overridden through
# an environment variable so the script can be pointed at another environment
# without editing the source; the defaults are the previous hard-coded values.
IP = os.environ.get("YACE_IP", "10.0.0.153")
PORTS = [8001, 8002]  # gateway ports; each client picks one at random
SID = 0  # passed as Sid in every client's login request
DB_HOST = os.environ.get("YACE_DB_HOST", "10.0.0.9")
DB_NAME = os.environ.get("YACE_DB_NAME", "xiangsu_s0_wnp")
DB_USER = os.environ.get("YACE_DB_USER", "root")
DB_PWD = os.environ.get("YACE_DB_PWD", "iamciniao")

# Database connection settings, consumed by mongo.MongoDB below and by getTableList().
DATABASE = {
    "host": DB_HOST,
    "port": 27017,
    "poolsize": 10,
    "dbname": DB_NAME,
    "dbuser": DB_USER,
    "dbpwd": DB_PWD,
    "authdb": "admin"
}
mdb = mongo.MongoDB(DATABASE)


# Get the names of all collections in the database.
def getTableList():
    """Return the names of all collections in ``DATABASE["dbname"]``.

    A short-lived pymongo client is opened with the host, port and credentials
    from ``DATABASE`` and is always closed afterwards, so repeated calls do not
    leak connection pools.

    :return: list of collection names
    """
    import pymongo

    # Pass the credentials to the constructor: Database.authenticate() no longer
    # exists in pymongo 4, and this keeps the values in sync with DATABASE.
    client = pymongo.MongoClient(
        host=DATABASE["host"],
        port=DATABASE["port"],
        username=DATABASE["dbuser"],
        password=DATABASE["dbpwd"],
        authSource=DATABASE["authdb"],
    )
    try:
        db = client.get_database(DATABASE["dbname"])
        return db.list_collection_names()
    finally:
        client.close()


# Number of clients assigned to each gateway port (filled by WebSocketClient.randPort).
PORT2NUM: dict = dict()
# Number of successful logins per worker URL reported by the server (filled in WebSocketClient.doRecv).
WORKERURL2NUM: dict = dict()


# Remove load-test accounts.
def clear():
    """Delete load-test accounts and every document that belongs to them.

    Test accounts are the ``userinfo`` documents whose ``binduid`` matches the
    regex "yace" (WebSocketClient.getBindUid generates ``yace_<ts>_<id>``).
    Documents whose ``uid`` is one of those accounts are deleted from every
    collection. Afterwards ``mdb.delete("apiCount")`` is called with no filter
    (presumably clearing that collection — confirm against mongo.MongoDB.delete).
    """
    _users = mdb.find("userinfo", {'binduid': {"$regex": "yace"}}, fields=["_id", "uid"])
    # Skip documents lacking a uid instead of aborting the whole cleanup with KeyError.
    _uidList = [u["uid"] for u in (_users or []) if "uid" in u]
    if _uidList:
        # Only open a connection to list collections when there is something to delete.
        for tableName in getTableList():
            mdb.delete(tableName, {"uid": {"$in": _uidList}})

    mdb.delete("apiCount")


from pb.pb_python.rpcx import rpcx_param_pb2
from pb.pb_python.user import user_api_pb2
from pb.pb_python.hero import hero_api_pb2


class Api(object):
    """Factory for gateway requests (``rpcx_param_pb2.Request``).

    Each public method addresses a server handler through ``ModuleName`` /
    ``ApiName`` and packs the matching argument message into ``Data``.
    """

    def getRequest(self, moduleName, apiName):
        """Return a request header for ``moduleName.apiName`` with no payload."""
        return rpcx_param_pb2.Request(ModuleName=moduleName, ApiName=apiName)

    def _packed(self, moduleName, apiName, payload):
        """Build the header via getRequest() and pack ``payload`` into its Data field."""
        req = self.getRequest(moduleName, apiName)
        req.Data.Pack(payload)
        return req

    def login(self, bindUid, sid):
        """Log in as account ``bindUid`` with ``sid``."""
        payload = user_api_pb2.LoginArgs(BindUid=bindUid, Sid=sid)
        return self._packed("user", "Login", payload)

    def ChangeName(self, name):
        """Rename the player to ``name``."""
        payload = user_api_pb2.ChangeNameArgs(Name=name)
        return self._packed("user", "ChangeName", payload)

    def getUserInfo(self, uid):
        """Fetch player data for ``uid``."""
        payload = user_api_pb2.UserGetInfoArgs(Uid=uid)
        return self._packed("user", "GetInfo", payload)

    def getHeroList(self):
        """Fetch the hero list."""
        payload = hero_api_pb2.HeroGetListArgs()
        return self._packed("hero", "GetList", payload)

    def getHeroInfo(self, oid):
        """Fetch data for the hero identified by ``oid``."""
        payload = hero_api_pb2.HeroGetInfoArgs(Oid=oid)
        return self._packed("hero", "GetInfo", payload)

    def getCrossUserInfo(self, uid):
        """Fetch player data for ``uid`` through the cross-server module."""
        payload = user_api_pb2.UserGetInfoArgs(Uid=uid)
        return self._packed("cross_user", "GetInfo", payload)


api = Api()


class WebSocketClient(object):
    """One simulated player on a single websocket connection.

    ``run()`` sends every entry of ``apiList()`` in order, waiting for each
    response, records per-API latency in ``api2time``, writes a report file
    and then blocks in ``over()`` until ``testNum`` clients have finished.
    """

    # Guards the read-modify-write updates of the shared PORT2NUM / WORKERURL2NUM
    # counters, which are touched concurrently from many client threads.
    _countLock = threading.Lock()

    def __init__(self, ID, islogin=False):
        self.ID = ID
        self.sid = SID
        self.isLogin = islogin  # only exercise the login flow
        self.bindUid = self.getBindUid()
        self.port = self.randPort()  # pick a random gateway port
        self.url = f"ws://{self.getMyUrl()}/gateway"
        self.ws = create_connection(self.url)
        self.gud = None
        self.uid = None
        self.hid2Info = {}
        self.hero_info = None
        self.item_info = None
        self.fightNum = 1
        self.state = ["接口全部执行成功", "success"]
        self.api2time = []
        self.doNum = 0
        self.workerUrl = None

    def randPort(self):
        """Pick one of PORTS at random and count the assignment in PORT2NUM."""
        _port = random.choice(PORTS)
        with self._countLock:
            PORT2NUM[_port] = PORT2NUM.get(_port, 0) + 1
        return _port

    def getMyUrl(self):
        """Return ``host:port`` of the gateway this client uses."""
        return f"{IP}:{self.port}"

    def getBindUid(self):
        """Return a unique test account id of the form ``yace_<unix seconds>_<ID>``."""
        return f"yace_{int(time.time())}_{self.ID}"

    def apiList(self):
        """Return the ordered API calls this client performs."""
        return [
            {"func": api.login, "args": [self.bindUid, self.sid], "sleep": 0},
            {"func": api.ChangeName, "args": [f"压测{self.ID}号"], "sleep": 2},
            {"func": api.getUserInfo, "args": [self.uid], "sleep": 1},
            {"func": api.getHeroList, "args": [], "sleep": 2},
            {"func": api.getHeroInfo, "args": [], "sleep": 1},
            {"func": api.getCrossUserInfo, "args": [], "sleep": 0}
        ]

    def recv(self):
        """Receive one frame; return it JSON-decoded when possible, otherwise unchanged."""
        _data = self.ws.recv()
        try:
            return json.loads(_data)
        except (ValueError, TypeError):
            # Binary protobuf frames are not JSON; hand them back as received.
            return _data

    def over(self):
        """Print this client's outcome, save its latency report and wait for the rest.

        Blocks until ``testNum`` clients have reached this point, so each
        connection stays open until the whole batch is done (presumably to keep
        the load concurrent — confirm).
        """
        _progress = f"进度({self.doNum}/{len(self.apiList())})"
        if not self.gud:
            print(f"over->【{self.bindUid}】{self.state[0]} {_progress}{self.state[1]}")
        else:
            print(f"over->{self.gud.Uid}【{self.bindUid}】{self.state[0]} {_progress}{self.state[1]}")

        # Report folder lives next to this script (same place start() prepares).
        _reportDir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "yace_baogao")
        try:
            os.makedirs(_reportDir, exist_ok=True)
            with open(os.path.join(_reportDir, f"{self.bindUid}.txt"), "w", encoding="utf-8") as f:
                f.write(json.dumps(self.api2time, ensure_ascii=False, indent=2))
        except OSError as e:
            # A failed report must not keep this client out of userSet below,
            # otherwise every other client would wait forever.
            print(f"failed to write report for {self.bindUid}: {e}")

        userSet.add(self.bindUid)
        while len(userSet) < testNum:
            time.sleep(0.5)

    # Decode an API response payload
    def deCodeApi(self, response):
        """Unpack ``response.Data`` into the message type registered for its API.

        :return: the decoded message, or None for APIs without a registered type
        """
        _api2Response = {
            # login
            "user.Login": user_api_pb2.LoginRes,
            # player data
            "user.GetInfo": user_api_pb2.UserGetInfoRes,
            # hero list
            "hero.GetList": hero_api_pb2.HeroGetListRes,
            # hero data
            "hero.GetInfo": hero_api_pb2.HeroGetInfoRes,
            # cross-server player data
            "cross_user.GetInfo": user_api_pb2.UserGetInfoRes,
        }

        _apiResponse = _api2Response.get(response.ResponseApi)
        if not _apiResponse:
            return

        _res = _apiResponse()
        response.Data.Unpack(_res)

        return _res

    def parseMessage(self, msg):
        """Parse a frame as ``rpcx_param_pb2.Response``.

        :return: ``(response, decoded payload)``, or ``(None, None)`` when the
            frame is not a parseable Response
        """
        try:
            _response = rpcx_param_pb2.Response()
            _response.ParseFromString(msg)
        except Exception:
            return None, None

        return _response, self.deCodeApi(_response)

    @staticmethod
    def _printPush(stringMsg):
        """Print a server push carried as text in ``Response.StringMsg``."""
        try:
            _msg = json.loads(stringMsg)
        except ValueError:
            # Not JSON: show it verbatim instead of crashing the receive loop.
            print(f"收到服务端推送-->{stringMsg}")
            return
        if isinstance(_msg, dict) and "await" in _msg:
            print(f"登陆需要排队等待{_msg['await']}秒")
        else:
            print(f"收到服务端推送-->{_msg}")

    def doRecv(self):
        """Read frames until the response to the pending request arrives.

        Frames without ``ResponseApi`` are treated as server pushes and printed.

        :return: None on success, otherwise an error description string
        """
        _sTime = time.time()
        _maxTime = 60 * 10

        # Login must return the player data before any other API can be called.
        while True:
            try:
                _data = self.recv()
            except Exception as e:
                return f"服务器异常{e}"

            _response, _apiRes = self.parseMessage(_data)
            if _response:
                _apiName = _response.ResponseApi
                if _response.ErrorMsg:
                    return f"{_apiName}==> {_response.ErrorMsg}"
                if _apiName:
                    if _apiName == "user.Login":
                        if hasattr(_apiRes, "WorkerUrl"):
                            # Record which worker the server assigned.
                            self.workerUrl = _apiRes.WorkerUrl
                            with self._countLock:
                                WORKERURL2NUM[self.workerUrl] = WORKERURL2NUM.get(self.workerUrl, 0) + 1

                            self.gud = _apiRes.Gud
                            self.uid = self.gud.Uid
                            print(f"{self.bindUid}【{self.gud.Name}】登陆成功~")
                            break
                    else:
                        if _apiName == "hero.GetList" and _apiRes is not None:
                            self.hid2Info = {h.Hid: h for h in _apiRes.HeroList}
                        break
                else:
                    self._printPush(_response.StringMsg)

            # Only checked after each received frame, since recv() blocks.
            if time.time() - _sTime >= _maxTime:
                return "登陆超时"

    def getAPiArgs(self, apiInfo):
        """Return call arguments, substituting values learned from earlier responses.

        :raises RuntimeError: hero.GetInfo requested before any hero was listed
        """
        _func = apiInfo["func"]
        if _func in [api.getUserInfo, api.getCrossUserInfo]:
            return [self.uid]
        if _func == api.getHeroInfo:
            if not self.hid2Info:
                raise RuntimeError("hero list is empty; cannot build hero.GetInfo request")
            # Use the first hero returned by hero.GetList.
            return [next(iter(self.hid2Info.values())).Id]

        return apiInfo.get("args", [])

    def doApi(self):
        """Send each entry of ``apiList()`` and wait for its response.

        Stops at the first error (stored in ``state``). In login-only mode
        (``isLogin``) it stops right after the login call.
        """
        for apiInfo in self.apiList():
            self.doNum += 1

            if apiInfo == "islogin":
                # Optional marker entry: stop here when only testing login.
                if self.isLogin:
                    break
                continue

            # NOTE: apiInfo["sleep"] is currently not applied (the sleep was commented out).
            _pb = apiInfo["func"](*self.getAPiArgs(apiInfo), **apiInfo.get("kwargs", {}))
            _api = f"{_pb.ModuleName}.{_pb.ApiName}"
            self.ws.send(_pb.SerializeToString())
            _sendTime = time.time()
            # Wait for this response before sending the next request.
            _res = self.doRecv()
            self.api2time.append(f"{_api} 耗时：{(time.time() - _sendTime) * 1000}")
            if _res:
                self.state = ["接口调用终止，error==>", _res + f"【{_api}】"]
                break
            if self.isLogin and _api == "user.Login":
                # apiList() has no "islogin" marker, so honour the flag here.
                break

    def run(self):
        """Run the API sequence and then over(); always closes the websocket."""
        try:
            self.doApi()
        except Exception as e:
            # over() must still run: every other client waits there until this
            # one is counted, so an escaped exception would hang the whole test.
            self.state = ["接口调用终止，error==>", f"客户端异常{e}"]
        try:
            self.over()
        finally:
            self.ws.close()


# API test using existing accounts.
class ApiTest(WebSocketClient):
    """Client that logs in as an existing account and runs ``apiList()``.

    ``runTime`` is an absolute timestamp compared against ``Common.NOW()``. An
    integer entry in ``apiList()`` makes the client sleep until that moment
    (presumably so many clients fire the following requests together).
    """

    def __init__(self, bindUid, islogin=False, runTime=0):
        super().__init__(bindUid, islogin=islogin)
        self.runTime = runTime
        # Log in as the given account instead of the generated "yace_..." id.
        self.bindUid = bindUid

    def doApi(self):
        """Send each entry of ``apiList()`` over the protobuf protocol.

        Integer entries wait until ``runTime``; dict entries honour their
        ``sleep`` before sending. Each request waits for its response, latency
        goes to ``api2time``, and the first error stops the sequence.
        """
        for i in self.apiList():
            self.doNum += 1

            if isinstance(i, int):
                # Sync marker: sleep until runTime before continuing.
                _sleepTime = self.runTime - Common.NOW()
                if _sleepTime > 0:
                    time.sleep(_sleepTime)
                continue

            _sleepTime = i["sleep"]
            if _sleepTime > 0:
                time.sleep(_sleepTime)

            _pb = i["func"](*self.getAPiArgs(i), **i.get("kwargs", {}))
            _api = f"{_pb.ModuleName}.{_pb.ApiName}"
            self.ws.send(_pb.SerializeToString())
            _sendTime = time.time()
            # Wait for the response before sending the next request.
            _res = self.doRecv()
            self.api2time.append(f"{_api} 耗时：{(time.time() - _sendTime) * 1000}")
            if _res:
                self.state = ["接口调用终止，error==>", _res + f"【{_api}】"]
                break

    def apiList(self):
        """Return the calls made per account: login, then the hero list."""
        return [
            {"func": api.login, "args": [self.bindUid, self.sid], "sleep": 0},
            {"func": api.getHeroList, "args": [], "sleep": 1}
        ]


def doRun(i, maxSemaphore, islogin=False):
    """Thread target: run one WebSocketClient while holding a concurrency slot.

    :param i: client index, used as the WebSocketClient ID
    :param maxSemaphore: semaphore limiting how many clients run at once
    :param islogin: forwarded to WebSocketClient (login-only mode)
    """
    maxSemaphore.acquire()
    try:
        client = WebSocketClient(i, islogin=islogin)
        client.run()
    finally:
        maxSemaphore.release()


def doApiRun(i, maxSemaphore, runTime):
    """Thread target: run one ApiTest client while holding a concurrency slot.

    :param i: bindUid of the existing account to log in as
    :param maxSemaphore: semaphore limiting how many clients run at once
    :param runTime: absolute timestamp forwarded to ApiTest
    """
    maxSemaphore.acquire()
    try:
        client = ApiTest(i, runTime=runTime)
        client.run()
    finally:
        maxSemaphore.release()


# Running counter used to number the client threads started by start()/apiTest().
userNum: int = 0
# bindUids of clients that reached WebSocketClient.over(); each waits until it holds testNum entries.
userSet: set = set()


def start(num, islogin=False, maxConcurrency=None):
    """Launch ``num`` WebSocketClient threads and print the distribution summary.

    :param num: number of simulated clients (requests)
    :param islogin: whether clients only test the login API
    :param maxConcurrency: maximum number of clients running at once; defaults
        to ``num`` (all at once, the previous behaviour)
    :return: None
    """
    global userNum

    # Recreate an empty report folder next to this script. abspath() matters:
    # a bare "script.py" __file__ has an empty dirname, which would have turned
    # the path into "/yace_baogao" at the filesystem root.
    _chkFolder = os.path.join(os.path.dirname(os.path.abspath(__file__)), "yace_baogao")
    if os.path.exists(_chkFolder):
        shutil.rmtree(_chkFolder)
    os.mkdir(_chkFolder)

    _sTime = time.time()
    _maxSemaphore = threading.BoundedSemaphore(num if maxConcurrency is None else maxConcurrency)
    _taskList = []
    for _ in range(num):
        userNum += 1
        t = threading.Thread(target=doRun, args=(userNum, _maxSemaphore, islogin))
        _taskList.append(t)
        t.start()
    for t in _taskList:
        t.join()

    print(f"端口分布数量：{PORT2NUM}")
    for url, count in WORKERURL2NUM.items():
        print(f"{url}分配的数量为：{count}")
    print(f"总耗时：{(time.time() - _sTime)}秒")


def apiTest(num, runTime, maxConcurrency=None):
    """Drive up to ``num`` existing accounts with ApiTest clients and print the summary.

    Accounts are read from ``userinfo`` (excluding binduid "wnp001"), sorted by
    ``mapid`` descending and limited to ``num``.

    :param num: maximum number of accounts to drive
    :param runTime: absolute timestamp forwarded to every ApiTest
    :param maxConcurrency: maximum number of clients running at once; defaults to ``num``
    """
    global userNum

    # WebSocketClient.over() writes its report into this folder, so make sure it
    # exists (unlike start(), keep any earlier reports).
    os.makedirs(os.path.join(os.path.dirname(os.path.abspath(__file__)), "yace_baogao"), exist_ok=True)

    _sTime = time.time()
    _maxSemaphore = threading.BoundedSemaphore(num if maxConcurrency is None else maxConcurrency)
    _taskList = []

    _userList = mdb.find("userinfo", {"binduid": {"$ne": "wnp001"}}, sort=[["mapid", -1]], limit=num,
                         fields=["_id", "binduid"])
    # clear() checks mdb.find results for truthiness before use; guard the same way.
    for user in _userList or []:
        userNum += 1
        t = threading.Thread(target=doApiRun, args=(user["binduid"], _maxSemaphore, runTime))
        _taskList.append(t)
        t.start()
    for t in _taskList:
        t.join()

    print(f"端口分布数量：{PORT2NUM}")
    for url, count in WORKERURL2NUM.items():
        print(f"{url}分配的数量为：{count}")
    print(f"总耗时：{(time.time() - _sTime)}秒")


testNum = 1  # WebSocketClient.over() blocks each client until this many have called it


def _main():
    """Script entry: purge earlier load-test accounts, then run ``testNum`` clients."""
    clear()
    start(testNum, islogin=False)
    # apiTest(500, Common.NOW() + 150)


if __name__ == '__main__':
    _main()
